# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import typing as ty

import numpy
from qat.lang.AQASM.program import Program
from qat.lang.AQASM.routines import QRoutine
from qaths.utils.qpus import GenericLinalgSimulator


def _get_circuit(routine: QRoutine, **to_circ_kwargs):
    """Embed ``routine`` in a fresh program and compile that program to a circuit.

    A new :class:`Program` allocates exactly ``routine.arity`` qubits. The
    routine is applied on all of them, and the program is then converted with
    ``Program.to_circ``.

    :param routine: The routine to embed in the program.
    :param to_circ_kwargs: Keyword arguments forwarded verbatim to
        ``Program.to_circ``.
    :return: The circuit returned by ``Program.to_circ``.
    """
    program = Program()
    register = program.qalloc(routine.arity)
    program.apply(routine, register)
    compiled_circuit = program.to_circ(**to_circ_kwargs)
    return compiled_circuit


def _get_state_without_ancilla(state: int, result_size: int, ancilla_size: int) -> int:
    binary_state = bin(state)[2:].zfill(result_size + ancilla_size)
    return int(binary_state[:result_size], 2)


def simulate_routine(
    routine: QRoutine, probability_threshold: float = 1e-10, **to_circ_kwargs
) -> numpy.ndarray:
    r"""Simulate the given routine and return the amplitudes of its qubits.

    The routine is compiled into a circuit (see ``_get_circuit``) and executed
    with ``GenericLinalgSimulator`` from the circuit's default initial state.
    NOTE(review): this is presumably :math:`|0\dots 0\rangle`; confirm against
    the qat documentation.

    :param routine: The routine to simulate.
    :param probability_threshold: Basis states whose probability is not strictly
        greater than this threshold are treated as having a zero amplitude.
    :param to_circ_kwargs: Keyword arguments forwarded to ``Program.to_circ``
        when compiling the routine.
    :return: A complex vector of size ``2 ** routine.arity`` holding the final
        amplitudes of the routine qubits. Any ancilla qubits added by the
        compilation are discarded by ignoring their bits in the basis-state
        index. This is only meaningful if the ancillas end in a single fixed
        state. Otherwise, samples that share the same routine-qubit bits
        overwrite each other in the returned vector.
    """
    circuit = _get_circuit(routine, **to_circ_kwargs)

    linalg_simulator = GenericLinalgSimulator()
    result = linalg_simulator.submit_job(circuit.to_job())

    # ``numpy.complex`` was only a deprecated alias of the builtin ``complex``.
    # It was removed in NumPy 1.24. ``complex128`` is the explicit equivalent.
    output_state = numpy.zeros((2 ** routine.arity,), dtype=numpy.complex128)
    # Number of qubits the compilation added on top of the routine's own qubits.
    circuit_ancillas = circuit.nbqbits - routine.arity

    for sample in result.raw_data:
        if sample.probability > probability_threshold:
            state = _get_state_without_ancilla(
                sample.state, routine.arity, circuit_ancillas
            )
            # Only build the complex amplitude for samples that are kept.
            output_state[state] = sample.amplitude.re + 1.0j * sample.amplitude.im
    return output_state


# _OP_MAP = {"C": ctrl, "D": dag}
#
# _P0 = numpy.array([[1, 0], [0, 0]], dtype=numpy.complex)
# _P1 = numpy.array([[0, 0], [0, 1]], dtype=numpy.complex)
# _ID = numpy.array([[1, 0], [0, 1]], dtype=numpy.complex)
#
#
# def get_matrix(
#     circuit, index: int, gate_set=None, endianness=BIG_ENDIAN
# ) -> numpy.ndarray:
#     """
#     Returns the matrix corresponding to the gate at index `index` in a circuit and the
#     target qubits of the gate.
#
#     :param circuit: The circuit containing the gate to translate.
#     :param index: The index of the gate to translate in the given `circuit`.
#     :param gate_set: The gate set considered. The matrix name should appear in the gate
#            set (if not, might raise an UnknownGate exception).
#     :param endianness: The endianness of the returned matrix.
#     """
#     if gate_set is None:
#         gate_set = default_gate_set()
#
#     op = circuit.ops[index]
#     name, parameters = extract_syntax(circuit.gateDic[op.gate], circuit.gateDic)
#     name = name.split("-")
#     root_name = name[-1]
#     operators = name[0:-1]
#
#     if root_name.lower().endswith("cnot"):
#         while root_name[0].lower() == "c":
#             operators.insert(0, "C")
#             root_name = root_name[1:]
#         root_name = "X"
#
#     matrix = numpy.array(gate_set[root_name].matrix_generator(*parameters))
#     trgt = op.qbits[-1]
#     ctrls = sorted(op.qbits[:-1])
#     # If the operation is not controlled by any qubit, the algorithm can be
#     # greatly simplified.
#     if not ctrls:
#         ret = 1
#         for qubit_index in range(circuit.nbqbits):
#             # If we are on the target qubit then apply the gate.
#             if qubit_index == trgt:
#                 ret = kron(ret, matrix)
#             # Otherwise, the qubit is untouched: multiply by the identity.
#             else:
#                 ret = kron(ret, _ID)
#         if endianness == LITTLE_ENDIAN:
#             reverse_endian(ret)
#         return ret
#
#     # Else, we have control qubits.
#     ret = numpy.zeros((2 ** circuit.nbqbits, 2 ** circuit.nbqbits), dtype=numpy.complex)
#     # For each possible value of the control qubits.
#     for ctrl_values in range(2 ** len(ctrls)):
#         ctrl_values_list = [(ctrl_values >> k) & 1 for k in range(len(ctrls))]
#         current_control_index = 0
#         current_matrix = 1
#         for qubit_index in range(circuit.nbqbits):
#             # If we are on the target qubit then...
#             if qubit_index == trgt:
#                 # If all the control qubits are 1, then multiply by the
#                 # gate.
#                 if all(ctrl_values_list):
#                     current_matrix = kron(current_matrix, matrix)
#                 # Otherwise, the gate is not applied: multiply by the identity.
#                 else:
#                     current_matrix = kron(current_matrix, _ID)
#             # Else if we are on a control qubit, determine if we should
#             # use P0 or P1 and apply it.
#             elif qubit_index in ctrls:
#                 current_matrix = kron(
#                     current_matrix,
#                     _P1 if ctrl_values_list[current_control_index] else _P0,
#                 )
#                 current_control_index += 1
#             # Otherwise, the current qubit is untouched (identity).
#             else:
#                 current_matrix = kron(current_matrix, _ID)
#         ret += current_matrix
#     if endianness == LITTLE_ENDIAN:
#         reverse_endian(ret)
#     return ret
#
#
# def get_matrices(circuit, gate_set=None, reverse: bool = False):
#     """
#     Iterates over the gates matrices and qubits of a circuit.
#     If no gate set is provided, uses the default gate set of pyAQASM to
#     generate the matrices.
#     """
#     indices = (
#         range(len(circuit.ops)) if not reverse else reversed(range(len(circuit.ops)))
#     )
#
#     def _get_matrix_from_index(i: int) -> numpy.ndarray:
#         return get_matrix(circuit, i, gate_set)
#
#     max_workers = 4
#     with ThreadPoolExecutor(max_workers=max_workers) as pool:
#         yield from pool.map(_get_matrix_from_index, indices)
#
#
# def chain_product(matrices) -> numpy.ndarray:
#     if not isinstance(matrices, collections.Iterator):
#         matrices = iter(matrices)
#     ret = next(matrices)
#     for i, mat in enumerate(matrices):
#         ret = numpy.dot(ret, mat)
#     return ret
#
#
# def routine2unitary(routine: QRoutine) -> numpy.ndarray:
#     """Computes the unitary matrix representing the routine given in parameter.
#
#     :param routine: A QRoutine with a reasonable arity (the resulting matrix
#         should be storable in RAM).
#     :return: the unitary matrix representing the given routine.
#     """
#     prog = Program()
#     qbits = prog.qalloc(routine.arity)
#     prog.apply(routine, qbits)
#     circ = prog.to_circ()
#     res = chain_product(get_matrices(circ, reverse=True))
#     return res
